'''
Distributed under the MIT License(MIT)

Copyright(c) 2023 Jihua Zou EMail: ghuazo@qq.com QQ:137336521

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files(the "Software"), to deal in the
Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and /or sell copies
of the Software, and to permit persons to whom the Software is furnished
to do so, subject to the following conditions :

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS
OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
import json
import os
import re
import akshare as ak
import streamlit as st
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression
from apollo_client import ApolloClient
from data_center import DataCenter

# Per-service handles. Both are rebound in the sidebar once a trading
# service has been picked: `apollo` to an ApolloClient(host, port) and
# `data` to a DataCenter(id).
data = None
apollo = None

# Trading back-ends offered in the sidebar. "name" is the label shown in the
# selectbox, "id" is handed to DataCenter, and host/port address ApolloClient.
services = [
    dict(id="gxzjh", name="国信实盘(127.0.0.1:19527)", host="127.0.0.1", port=19527),
    dict(id="simnow", name="SIMNOW(127.0.0.1:19528)", host="127.0.0.1", port=19528),
]


# Fetch futures daily K-line (candlestick) data and normalise it.
@st.cache_data(ttl=300, show_spinner="正在获取期货数据...")
def _fetch_kline_data(code):
    """Download the daily K-line history for ``code`` and keep the last 30 rows.

    The Sina symbol is the text after the last '.' in ``code`` (the whole
    string if there is no '.'), upper-cased.

    Results are cached for 300 seconds. Errors are deliberately NOT caught
    here: st.cache_data does not store the result of a call that raises, so a
    transient failure is retried on the next rerun instead of being cached
    (together with its error message) as an empty frame for five minutes.
    """
    symbol = code.split('.')[-1].upper()
    result = ak.futures_zh_daily_sina(symbol=symbol)
    if result.empty:
        return pd.DataFrame()
    # Map Chinese column names, if present, to the English names that the
    # chart and regression code read.
    result = result.rename(columns={
        '日期': 'date',
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
    })
    result['date'] = pd.to_datetime(result['date'])
    # Date-indexed, oldest first; only the most recent 30 rows are kept.
    return result.set_index('date').sort_index().tail(30)


def get_kline_data(code):
    """Return the most recent 30 daily K-lines for ``code``.

    Args:
        code: contract code as chosen in the sidebar; may be None/empty.

    Returns:
        A date-indexed DataFrame, oldest row first. The chart code reads its
        open/high/low/close columns. An empty DataFrame is returned when
        ``code`` is empty, when no data exists, or when the download fails.
        In the failure case an error is also shown on the page.
    """
    if not code:
        return pd.DataFrame()
    try:
        return _fetch_kline_data(code)
    except Exception as e:  # best-effort boundary: any failure degrades to "no data"
        st.error(f"数据获取失败: {str(e)}")
        return pd.DataFrame()


# Keep the cache-control API that callers had when get_kline_data itself was
# the cached function (e.g. get_kline_data.clear()).
get_kline_data.clear = _fetch_kline_data.clear

def calculate_regression(data, n_days):
    """Fit a least-squares line to the most recent closing prices.

    Args:
        data: DataFrame with a 'close' column. Rows are assumed to be in
            chronological order (get_kline_data sorts by date).
        n_days: number of trailing rows to fit. It is converted to int and
            clamped to ``[1, len(data)]``. Previously a non-positive value
            broke the fit: ``DataFrame.tail(-k)`` returns all rows *except*
            the first k, which mismatched ``np.arange(-k)``.

    Returns:
        ``(slope, index, trend_line)``: the slope in price units per row
        (trading day), the index of the fitted rows, and the fitted values at
        those rows as a numpy array.

    Raises:
        ValueError: if ``data`` has no rows.
    """
    if len(data) == 0:
        raise ValueError("calculate_regression() requires at least one row of data")
    n_days = max(1, min(int(n_days), len(data)))
    data_subset = data.tail(n_days)

    # X is the relative position 0..n_days-1 of each row in the window.
    X = np.arange(n_days).reshape(-1, 1)
    y = data_subset['close'].values
    model = LinearRegression().fit(X, y)

    # Predict over the same X range to get the in-window trend line.
    trend_line = model.predict(X)

    return model.coef_[0], data_subset.index, trend_line


# 使用示例合约代码
contract_data = None
trading_contract = None

# on_click handler factories for the "sync" buttons.
def sync_attribute_callback(params):
    """Build an ``on_click`` handler that pushes ``params`` to the trading service.

    When invoked, the handler shows a spinner and calls
    ``apollo.sync_attribute(trading_contract, params)``. It then stores
    ``apollo.get_runinfo(trading_contract)`` in
    ``st.session_state.target_data``. ``apollo`` and ``trading_contract`` are
    module globals looked up when the handler runs, not when it is built.
    """
    def _push_attributes():
        with st.spinner("正在同步配置..."):
            apollo.sync_attribute(trading_contract, params)
            st.session_state.target_data = apollo.get_runinfo(trading_contract)

    return _push_attributes

def sync_position_callback(params):
    """Build an ``on_click`` handler that sends a position target to the service.

    ``params`` is forwarded unchanged to ``apollo.sync_position``. The handler
    also passes ``st.session_state.rundata['use_force']`` as read at click
    time, then refreshes ``st.session_state.target_data`` from
    ``apollo.get_runinfo``. ``apollo`` and ``trading_contract`` are module
    globals resolved when the handler runs.
    """
    def _push_position():
        with st.spinner("正在同步仓位..."):
            apollo.sync_position(
                trading_contract,
                params,
                st.session_state.rundata['use_force'],
            )
            st.session_state.target_data = apollo.get_runinfo(trading_contract)

    return _push_position
# Page layout
st.set_page_config(layout="wide")

# Sidebar: trading-service / contract selection, strategy parameters, manual
# position controls and system toggles. Widget callbacks write their values into
# st.session_state.rundata, which is handed to data.set_strategy_data() near the
# end of this block on every rerun.
with st.sidebar:
    st.header("策略控制台")
    selected_name = st.selectbox("选择交易端", [service["name"] for service in services], index=0)
    selected_service = next((s for s in services if s["name"] == selected_name), None)
    if selected_service:
        apollo = ApolloClient(selected_service['host'], selected_service['port'])
        data = DataCenter(selected_service['id'])
    print('service : ', selected_service['id'], selected_service["name"])
    contracts = apollo.get_contract()
    trading_contract = st.selectbox("选择合约", contracts, index=0)
    contract_data = get_kline_data(trading_contract)

    # Reload run info / strategy data whenever the contract OR the trading
    # service changes. Checking only the contract meant that switching services
    # while the same contract stayed selected kept the previous service's
    # rundata, which set_strategy_data() below then handed to the newly
    # selected service's DataCenter.
    if ('current_contract' not in st.session_state
            or st.session_state.current_contract != trading_contract
            or st.session_state.get('current_service') != selected_service['id']):
        st.session_state.target_data = apollo.get_runinfo(trading_contract)
        st.session_state.rundata = data.get_strategy_data(trading_contract)
        st.session_state.current_contract = trading_contract
        st.session_state.current_service = selected_service['id']

    def update_n_days():
        st.session_state.rundata['n_days'] = st.session_state.n_days_slider
    st.slider("趋势计算天数", 3, 15, st.session_state.rundata['n_days'],
              key='n_days_slider', on_change=update_n_days,
              help="选择用于计算趋势线和止损线的最近交易日数量")

    def update_price_range():
        st.session_state.rundata['price_min'] = st.session_state.price_range_slider[0]
        st.session_state.rundata['price_max'] = st.session_state.price_range_slider[1]
    st.slider("可交易区间", -10.0, 10.0,
              (st.session_state.rundata['price_min'], st.session_state.rundata['price_max']),
              key='price_range_slider', on_change=update_price_range,
              help="设置止盈止损线距离趋势线的百分比")

    if st.button('刷新数据', help="点击刷新", use_container_width=True):
        st.session_state.target_data = apollo.get_runinfo(trading_contract)
        st.rerun()

    # Tabs give mutually exclusive navigation between the three panels.
    tab_manual, tab_grid, tab_setting = st.tabs(["🕹️ 手动操作", "📊 网格托管", "🔗 系统设置"])
    # Manual controls are disabled while grid management is on, and grid
    # settings are disabled while it is off.
    auto_state = st.session_state.rundata['auto_state']

    with tab_manual:
        st.subheader("期货账户操作")
        st.button('📈 满仓开多', help="点击满仓开多", use_container_width=True,
                  disabled=auto_state, on_click=sync_position_callback(1))
        st.button('📉 满仓开空', help="点击满仓开空", use_container_width=True,
                  disabled=auto_state, on_click=sync_position_callback(-1))
        st.button('🛑 全部平仓', help="点击全部平仓", use_container_width=True,
                  disabled=auto_state, on_click=sync_position_callback(0))
        with st.expander("精确控制持仓", True):
            target = st.session_state.target_data
            max_position = target['max_position'] if target else 0
            # Without run info, or with no position capacity, the slider range
            # would collapse to a single point and the ratio below would divide
            # by zero. Show a hint instead of crashing the page.
            if max_position > 0:
                position_points = st.slider("调整仓位", -max_position, max_position,
                                            target['position'],
                                            help="设置当前持仓比例", disabled=auto_state)
                ratio = position_points / float(max_position)
                st.button('📩 确定同步', help="点击同步当前持仓", use_container_width=True,
                          disabled=auto_state, on_click=sync_position_callback(ratio))
            else:
                st.info("暂无可调整的持仓数据")

    with tab_grid:
        st.subheader("网格策略设置")

        def update_grid_base():
            st.session_state.rundata['grid_base'] = st.session_state.grid_base_slider
        st.slider("网格基准", -10.0, 10.0, st.session_state.rundata['grid_base'],
                  key='grid_base_slider', on_change=update_grid_base, disabled=not auto_state)

        def update_grid_density():
            st.session_state.rundata['grid_density'] = st.session_state.grid_density_slider
        st.slider("网格密度", 0.0, 1.0, st.session_state.rundata['grid_density'],
                  key='grid_density_slider', on_change=update_grid_density, disabled=not auto_state)

        def update_grid_size():
            st.session_state.rundata['grid_size'] = st.session_state.grid_size_slider
        st.slider("网格大小", 1, 20, st.session_state.rundata['grid_size'],
                  key='grid_size_slider', on_change=update_grid_size, disabled=not auto_state)

    with tab_setting:
        st.subheader("系统设置")

        def update_auto_state():
            st.session_state.rundata['auto_state'] = st.session_state.auto_state_toggle
        st.toggle("启用网格托管",
                  value=st.session_state.rundata['auto_state'],
                  key="auto_state_toggle",
                  on_change=update_auto_state)

        def update_use_force():
            st.session_state.rundata['use_force'] = st.session_state.use_force_toggle
        st.toggle("强制跳过风控",
                  value=st.session_state.rundata['use_force'],
                  key="use_force_toggle",
                  on_change=update_use_force)

    # Persist the (possibly updated) strategy parameters for this contract.
    data.set_strategy_data(trading_contract, st.session_state.rundata)

    # Footer, intended to sit at the bottom of the sidebar.
    st.markdown("---")

    st.markdown("""
    <div style="margin-top: auto">
        <p style="text-align: center">©80后程序员 | QQ:3844792568 版权所有</p>
        <p style="text-align: center"><a href="https://gitee.com/lightning-trader/lightning-rebalance">访问项目主页</a></p>
    </div>
    """, unsafe_allow_html=True)
# Main display area: status metrics, the K-line chart with trend/range/grid
# overlays, and buttons that push the computed parameters to the service.


def _percent_offset_line(line, percent):
    """Return ``line`` shifted element-wise by ``percent`` percent of itself."""
    return line + percent * line / 100


def _add_line_trace(fig, x, y, name, **line_style):
    """Append a ``mode='lines'`` Scatter trace, styled by ``line_style``, to ``fig``."""
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='lines',
        name=name,
        line=dict(**line_style),
    ))


def _render_status_metrics(target, grid_enabled):
    """Render the five run-info metrics in a row of columns.

    ``target`` is the dict stored from ``apollo.get_runinfo()``; only the keys
    read below are used. Price-gap deltas are shown only when
    ``current_price`` is positive.
    """
    col1, col2, col3, col4, col5 = st.columns(5)
    has_price = target['current_price'] > .0
    with col1:
        st.metric("持仓", f"{target['position']}/{target['max_position']}",
                  delta="dirty" if target['sleep'] else None)
    with col2:
        gap = f"{abs(target['max_price'] - target['current_price']):.2f}" if has_price else None
        st.metric("价格上限", f"{target['max_price']:.2f}", gap)
    with col3:
        gap = f"{abs(target['current_price'] - target['min_price']):.2f}" if has_price else None
        st.metric("价格下限", f"{target['min_price']:.2f}", gap)
    with col4:
        if grid_enabled:
            grid_price = target['grid_price']
            grid_density = target['grid_density']
            st.metric("网格数据",
                      f"{(grid_price - grid_density):.2f}~{(grid_price + grid_density):.2f}",
                      grid_density)
        else:
            st.metric("网格数据", "0.00~0.00", "未启用")
    with col5:
        st.metric("日成长", f"{target['trand_slope']:.2f}")


def _build_kline_figure(contract_data, rundata, trend_dates, trend_line,
                        min_range_line, max_range_line, grid_base_line):
    """Build the candlestick chart with trend, range band and optional grid lines.

    All overlay lines span only the regression window ``trend_dates``.
    """
    fig = go.Figure()
    fig.add_trace(go.Candlestick(
        x=contract_data.index,
        open=contract_data['open'],
        high=contract_data['high'],
        low=contract_data['low'],
        close=contract_data['close'],
        increasing_line_color='red',     # rising candles drawn red
        decreasing_line_color='green',   # falling candles drawn green
        name='K线'
    ))

    # Double quotes outside, single inside: reusing the same quote type inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    _add_line_trace(fig, trend_dates, trend_line, f"{rundata['n_days']}日趋势线",
                    color='blue', width=2)
    _add_line_trace(fig, trend_dates, min_range_line, '下边界线',
                    color='red', width=2, dash='dot')
    _add_line_trace(fig, trend_dates, max_range_line, '上边界线',
                    color='red', width=2, dash='dot')

    if rundata['auto_state']:
        _add_line_trace(fig, trend_dates, grid_base_line, '基准线',
                        color='purple', width=2, dash='dash')
        # Grid lines: green above the base line, then red below it. The i-th
        # line is offset by grid_density * (i + 1) percent of the base line.
        # A line is drawn only if it stays strictly inside the range band over
        # the whole window.
        for color, above in (('green', True), ('red', False)):
            for i in range(rundata['grid_size']):
                offset = rundata['grid_density'] * (i + 1) * grid_base_line / 100.0
                grid_line = grid_base_line + offset if above else grid_base_line - offset
                if ((min_range_line < grid_line) & (grid_line < max_range_line)).all():
                    _add_line_trace(fig, trend_dates, grid_line, '网格线',
                                    color=color, width=1, dash='dashdot')

    fig.update_layout(
        height=600,
        xaxis_rangeslider_visible=False,
        xaxis_title='日期',
        yaxis_title='价格',
        showlegend=True,
        xaxis=dict(
            type='date',
            range=[contract_data.index.min(), contract_data.index.max()]
        )
    )
    # Category axis with one slot per row, presumably so non-trading days leave
    # no gaps. NOTE(review): this overrides the 'date' type set above, and
    # Plotly documents category-axis ranges as category positions, so the date
    # range above may not apply as intended. Confirm.
    fig.update_xaxes(
        type='category',
        tickmode='array',
        tickvals=contract_data.index,
        ticktext=contract_data.index.strftime('%Y%m%d')
    )
    return fig


if not contract_data.empty:
    rundata = st.session_state.rundata

    if st.session_state.target_data:
        _render_status_metrics(st.session_state.target_data, rundata['auto_state'])

    # Trend over the last n_days closes. The range and grid lines are percent
    # offsets of it, valid only within that window.
    slope, trend_dates, trend_line = calculate_regression(contract_data, rundata['n_days'])
    min_range_line = _percent_offset_line(trend_line, rundata['price_min'])
    max_range_line = _percent_offset_line(trend_line, rundata['price_max'])
    grid_base_line = _percent_offset_line(trend_line, rundata['grid_base'])

    fig = _build_kline_figure(contract_data, rundata, trend_dates, trend_line,
                              min_range_line, max_range_line, grid_base_line)
    st.plotly_chart(fig, use_container_width=True)

    # Project one row (trading day) ahead by adding the fitted daily slope.
    latest_close = contract_data['close'].iloc[-1]
    expected_price = latest_close + slope
    current_min_price = min_range_line[-1] + slope
    current_max_price = max_range_line[-1] + slope
    current_grid_base = grid_base_line[-1] + slope
    # Grid step in price units: grid_density percent of the projected base.
    grid_step = rundata['grid_density'] * current_grid_base / 100.0

    act1, act2, act3, act4 = st.columns(4)
    with act1:
        st.metric("当前趋势", f"{latest_close:.2f}->{expected_price:.2f}",
                  delta=f"{slope:.2f}点/日")
        st.button('🌵 同步成长数据', help="点击提交当前策略配置", use_container_width=True,
                  on_click=sync_attribute_callback({
                      'trand_slope': slope
                  }))
    with act2:
        st.metric("最高价位", f"{current_max_price:.2f}",
                  delta=f"间距: {abs(current_max_price - expected_price):.2f}")
        st.button('🍖 同步最高价位', help="点击提交当前策略配置", use_container_width=True,
                  on_click=sync_attribute_callback({
                      'max_price': current_max_price
                  }))
    with act3:
        st.metric("最低价位", f"{current_min_price:.2f}",
                  delta=f"间距: {abs(expected_price - current_min_price):.2f}")
        st.button('🥧 同步最低价位', help="点击提交当前策略配置", use_container_width=True,
                  on_click=sync_attribute_callback({
                      'min_price': current_min_price
                  }))
    with act4:
        st.metric("网格配置", f"{current_grid_base:.2f}x{rundata['grid_size']}",
                  delta=f"密度: {grid_step:.2f}")
        st.button('📊 同步网格配置', help="点击提交当前策略配置", use_container_width=True,
                  on_click=sync_attribute_callback({
                      'grid_base': current_grid_base,
                      'grid_density': grid_step,
                      'grid_size': rundata['grid_size'],
                      'use_grid': rundata['auto_state'],
                  }))

    st.button('🚀 一键同步所有属性', help="点击提交当前策略配置", use_container_width=True,
              on_click=sync_attribute_callback({
                  'trand_slope': slope,
                  'max_price': current_max_price,
                  'min_price': current_min_price,
                  'grid_base': current_grid_base,
                  'grid_density': grid_step,
                  'grid_size': rundata['grid_size'],
                  'use_grid': rundata['auto_state'],
              }))

    # Raw data, newest row first.
    with st.expander("查看完整历史数据"):
        st.dataframe(contract_data.sort_index(ascending=False))
else:
    st.warning("该合约暂无有效数据")
